{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Out-of-Sample (OOS) Embedding"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Suppose we've embedded the nodes of a graph into Euclidean space using Adjacency Spectral Embedding (ASE) or Laplacian Spectral Embedding (LSE).  \n",
    "Then, suppose we gain access to new nodes not seen in the original graph. We sometimes wish to determine their latent positions without the computationally expensive task of re-embedding an entirely new adjacency matrix.\n",
    "\n",
    "More formally, suppose we have computed the embedding $\\hat{X} \\in \\textbf{R}^{n \\times d}$ from some adjacency matrix $A \\in \\textbf{R}^{n \\times n}$.  \n",
    "Suppose we then obtain some new vertex with adjacency vector $w \\in \\textbf{R}^n$ or new vertices with \"adjacency\" matrix $W \\in \\textbf{R}^{m \\times n}$, with $m$ the number of new vertices. We wish to estimate the latent positions for these new vertices.\n",
    "\n",
    "Here, an \"adjacency vector\" $w$ is a vector with $n$ elements, $n$ being the number of in-sample vertices. In the unweighted case, it has a 1 in the $i^{th}$ position if the out-of-sample vertex has an edge with in-sample vertex $i$, and a 0 otherwise.\n",
    "\n",
    "$W \\in \\textbf{R}^{m \\times n}$ is a matrix with each row being an adjacency vector, for $m$ out-of-sample vertices.\n",
    "\n",
    "We can obtain this estimation with ASE's or LSE's `transform` method.  \n",
    "Running through the Adjacency Spectral Embedding tutorial is recommended prior to this tutorial."
   ]
  },
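  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of what such an estimate can look like for ASE, assume the out-of-sample extension is a linear least-squares fit. This is an illustration under that assumption, not a description of graspologic's internals, and the symbols $\\hat{w}$, $\\hat{U}$, and $\\hat{S}$ are introduced here only for this sketch. Since $A \\approx \\hat{X}\\hat{X}^\\top$ for an undirected graph, the adjacency vector of a new vertex should satisfy $w \\approx \\hat{X}\\hat{w}$, which suggests\n",
    "\n",
    "$$\\hat{w} = \\arg\\min_{x \\in \\textbf{R}^d} \\lVert w - \\hat{X}x \\rVert_2^2 = (\\hat{X}^\\top \\hat{X})^{-1} \\hat{X}^\\top w,$$\n",
    "\n",
    "assuming $\\hat{X}$ has full column rank. If $\\hat{X} = \\hat{U}\\hat{S}^{1/2}$, with $\\hat{U}$ the top $d$ singular vectors of $A$ and $\\hat{S}$ the diagonal matrix of the corresponding singular values, then $\\hat{X}^\\top \\hat{X} = \\hat{S}$ and the estimate simplifies to\n",
    "\n",
    "$$\\hat{w} = \\hat{S}^{-1/2} \\hat{U}^\\top w.$$\n",
    "\n",
    "Stacking $m$ new vertices row-wise gives $W \\hat{U} \\hat{S}^{-1/2} \\in \\textbf{R}^{m \\times d}$. Under this reading, no new decomposition is needed, only a matrix product with quantities already computed during the original embedding."
   ]
  },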
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "\n",
    "from numpy.random import normal, poisson\n",
    "\n",
    "from graspologic.simulations import sbm\n",
    "from graspologic.embed import AdjacencySpectralEmbed as ASE\n",
    "from graspologic.embed import LaplacianSpectralEmbed as LSE\n",
    "from graspologic.plot import heatmap, pairplot\n",
    "from graspologic.utils import remove_vertices\n",
    "\n",
    "np.random.seed(1234)\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Undirected out-of-sample prediction"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here, we embed an undirected two-block stochastic block model with ASE. We then use its transform method to find an out-of-sample prediction for both a single vertex and multiple vertices.  \n",
    "\n",
    "We begin by generating data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate parameters\n",
    "nodes_per_community = 100\n",
    "P = np.array([[0.8, 0.2],\n",
    "              [0.2, 0.8]])\n",
    "\n",
    "# Generate an undirected Stochastic Block Model (SBM)\n",
    "undirected, labels_ = sbm(2*[nodes_per_community], P, return_labels=True)\n",
    "labels = list(labels_)\n",
    "\n",
    "# Grab out-of-sample vertices\n",
    "oos_idx = 0\n",
    "oos_labels = labels.pop(oos_idx)\n",
    "A, a = remove_vertices(undirected, indices=oos_idx, return_removed=True)\n",
    "\n",
    "# plot our SBM\n",
    "heatmap(A, title=f'2-block SBM (undirected), shape {A.shape}', inner_hier_labels=labels);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Embedding (ASE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We then generate an embedding with ASE, and we use its `transform` method to determine our best estimate for the latent position of the out-of-sample vertex."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate an embedding with ASE\n",
    "ase = ASE(n_components=2)\n",
    "X_hat_ase = ase.fit_transform(A)\n",
    "\n",
    "# predicted latent positions\n",
    "w_ase = ase.transform(a)\n",
    "w_ase"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Plotting out-of-sample embedding"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here, we plot the original latent positions as well as the out-of-sample vertices. Note that the out-of-sample vertices are near their expected latent positions despite not having been run through the original embedding.  \n",
    "In this plot, the stars are the out-of-sample latent positions, and the dots are the in-sample latent positions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_oos(X_hat, oos_vertices, labels, oos_labels, title):\n",
    "    # Plot the in-sample latent positions\n",
    "    plot = pairplot(X_hat, labels=labels, title=title)\n",
    "\n",
    "    # generate an out-of-sample dataframe\n",
    "    oos_vertices = np.atleast_2d(oos_vertices)\n",
    "    data = {'Type': oos_labels, \n",
    "          'Dimension 1': oos_vertices[:, 0], \n",
    "          'Dimension 2': oos_vertices[:, 1]}\n",
    "    oos_df = pd.DataFrame(data=data)\n",
    "    \n",
    "    # update plot with out-of-sample latent positions,\n",
    "    # plotting out-of-sample latent positions as stars\n",
    "    plot.data = oos_df\n",
    "    plot.hue_vals = oos_df[\"Type\"]\n",
    "    plot.map_offdiag(sns.scatterplot, s=500,\n",
    "                     marker=\"*\", edgecolor=\"black\")\n",
    "    plot.tight_layout()\n",
    "    return plot\n",
    "\n",
    "# Plot all latent positions\n",
    "plot_oos(X_hat_ase, w_ase, labels=labels, oos_labels=[0], title=\"ASE Out-of-Sample Embeddings (2-block SBM)\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Embedding (LSE)\n",
    "Similarly, we can use Laplacian Spectral Embedding (LSE). We generate an embedding with LSE, then use its `transform` method to determine our best estimate for the latent position of the out-of-sample vertex."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate an embedding with LSE\n",
    "lse = LSE(n_components=2)\n",
    "X_hat_lse = lse.fit_transform(A)\n",
    "\n",
    "# predicted latent positions\n",
    "w_lse = lse.transform(a)\n",
    "w_lse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_oos(X_hat_lse, w_lse, labels=labels, oos_labels=[0], title=\"LSE Out-of-Sample Embeddings (2-block SBM)\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Passing in multiple out-of-sample vertices"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can pass a 2D NumPy array into `transform`. Each row is the adjacency vector of one out-of-sample vertex, and each column corresponds to one in-sample vertex."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grab out-of-sample vertices\n",
    "labels = list(labels_)\n",
    "oos_idx = [0, -1]\n",
    "oos_labels = [labels.pop(i) for i in oos_idx]\n",
    "A, a = remove_vertices(undirected, indices=oos_idx, return_removed=True)\n",
    "\n",
    "# our out-of-sample array is m x n\n",
    "print(f\"a is {type(a)} with shape {a.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate an embedding with ASE\n",
    "ase = ASE(n_components=2)\n",
    "X_hat_ase = ase.fit_transform(A)\n",
    "\n",
    "# predicted latent positions\n",
    "w_ase = ase.transform(a)\n",
    "print(f\"The out-of-sample prediction output has dimensions {w_ase.shape}\\n\")\n",
    "\n",
    "# Plot all latent positions\n",
    "plot_oos(X_hat_ase, w_ase, labels, oos_labels=oos_labels,\n",
    "         title=\"ASE Out-of-Sample Embeddings (2-block SBM)\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate an embedding with LSE\n",
    "lse = LSE(n_components=2)\n",
    "X_hat_lse = lse.fit_transform(A)\n",
    "\n",
    "# predicted latent positions\n",
    "w_lse = lse.transform(a)\n",
    "print(f\"The out-of-sample prediction output has dimensions {w_lse.shape}\\n\")\n",
    "\n",
    "# Plot all latent positions\n",
    "plot_oos(X_hat_lse, w_lse, labels, oos_labels=oos_labels,\n",
    "         title=\"LSE Out-of-Sample Embeddings (2-block SBM)\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Directed out-of-sample prediction"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Not all graphs are undirected. For a directed graph, $A \\in \\textbf{R}^{n \\times n}$ is generally not symmetric: $A_{i,j}$ represents the edge from node $i$ to node $j$, whereas $A_{j, i}$ represents the edge from node $j$ to node $i$.\n",
    "\n",
    "To account for this, we pass a tuple `(out_oos, in_oos)` into the `transform` method. It then outputs a tuple `(out_latent_prediction, in_latent_prediction)`.  \n",
    "Here, \"out\" means \"edges from out-of-sample vertices to in-sample vertices\", and \"in\" means \"edges from in-sample vertices to out-of-sample vertices\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate a directed SBM\n",
    "directed = sbm(2*[nodes_per_community], P, directed=True)\n",
    "oos_idx = [0, -1]\n",
    "\n",
    "# a is a tuple of (out_oos, in_oos)\n",
    "A, a = remove_vertices(directed, indices=oos_idx, return_removed=True)\n",
    "\n",
    "# Plot the in-sample adjacency matrix (out-of-sample vertices removed)\n",
    "heatmap(A, title=f'2-block SBM (directed), shape {A.shape}');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit our directed graph\n",
    "X_hat_ase, Y_hat_ase = ase.fit_transform(A)\n",
    "\n",
    "# predicted latent positions\n",
    "w_ase = ase.transform(a)\n",
    "print(f\"output of `ase.transform(a)` is {type(w_ase)}\", \"\\n\")\n",
    "print(f\"out latent positions: \\n{w_ase[0]}\\n\")\n",
    "print(f\"in latent positions: \\n{w_ase[1]}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plotting directed ASE latent predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_oos(X_hat_ase, w_ase[0], labels, oos_labels=oos_labels, title=\"ASE Out Latent Predictions\")\n",
    "plot_oos(Y_hat_ase, w_ase[1], labels, oos_labels=oos_labels, title=\"ASE In Latent Predictions\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit our directed graph\n",
    "X_hat_lse, Y_hat_lse = lse.fit_transform(A)\n",
    "\n",
    "# predicted latent positions\n",
    "w_lse = lse.transform(a)\n",
    "print(f\"output of `lse.transform(a)` is {type(w_lse)}\", \"\\n\")\n",
    "print(f\"out latent positions: \\n{w_lse[0]}\\n\")\n",
    "print(f\"in latent positions: \\n{w_lse[1]}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plotting directed LSE latent predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_oos(X_hat_lse, w_lse[0], labels, oos_labels=oos_labels, title=\"LSE Out Latent Predictions\")\n",
    "plot_oos(Y_hat_lse, w_lse[1], labels, oos_labels=oos_labels, title=\"LSE In Latent Predictions\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Weighted out-of-sample prediction"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Weighted graphs work as well. Here, we generate a directed, weighted graph and estimate the latent positions for multiple out-of-sample vertices."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate a weighted, directed SBM\n",
    "wt = [[normal, poisson],\n",
    "      [poisson, normal]]\n",
    "wtargs = [[dict(loc=3, scale=1), dict(lam=5)],\n",
    "          [dict(lam=5), dict(loc=3, scale=1)]]\n",
    "weighted = sbm(2*[nodes_per_community], P, wt=wt, wtargs=wtargs, directed=True)\n",
    "\n",
    "# Generate out-of-sample vertices\n",
    "oos_idx = [0, -1]\n",
    "A, a = remove_vertices(weighted, indices=oos_idx, return_removed=True)\n",
    "\n",
    "# Plot our weighted, directed SBM\n",
    "heatmap(A, title=f'2-block SBM (directed, weighted), shape {A.shape}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embed and transform\n",
    "X_hat_ase, Y_hat_ase = ase.fit_transform(A)\n",
    "w_ase = ase.transform(a)\n",
    "\n",
    "# Plot\n",
    "plot_oos(X_hat_ase, w_ase[0], labels, oos_labels=oos_labels, title=\"ASE Out Latent Predictions\")\n",
    "plot_oos(Y_hat_ase, w_ase[1], labels, oos_labels=oos_labels, title=\"ASE In Latent Predictions\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embed and transform\n",
    "X_hat_lse, Y_hat_lse = lse.fit_transform(A)\n",
    "w_lse = lse.transform(a)\n",
    "\n",
    "# Plot\n",
    "plot_oos(X_hat_lse, w_lse[0], labels, oos_labels=oos_labels, title=\"LSE Out Latent Predictions\")\n",
    "plot_oos(Y_hat_lse, w_lse[1], labels, oos_labels=oos_labels, title=\"LSE In Latent Predictions\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
